import torch
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torchvision
from PIL import Image
from xml.dom.minidom import parse
import utils
import transforms as T
from engine import train_one_epoch, evaluate
import xml.etree.cElementTree as ET
import collections
import pandas as pd
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import sys
import random

#%%
# Dataset definition: subclass torch.utils.data.Dataset, then instantiate it
# with the dataset root directory and the relevant options.
class MarkDataset(torch.utils.data.Dataset):
    """VOC-style object-detection dataset.

    Expected directory layout under ``root``::

        root/train/JPEGImages        root/train/Annotations
        root/validation/JPEGImages   root/validation/Annotations

    Each annotation is an XML file whose ``<object>`` elements carry a
    ``<name>`` label and a ``<bndbox>`` with xmin/ymin/xmax/ymax pixel
    coordinates.  ``__getitem__`` returns ``(image, target)`` where
    ``target`` is the dict format expected by torchvision detection models.
    """

    def __init__(self, root, transforms=None, type='train', label_list=None, num_class=2):
        # root:       dataset root directory
        # transforms: callable applied as transforms(img, target); should at
        #             least convert the PIL image to a tensor
        # type:       'train' or 'val'/'validation' — selects the split
        # label_list: ordered class names; a name's index is its class id
        # num_class:  number of classes (stored for callers; unused here)
        self.type = type
        self.num_class = num_class
        # A None default avoids the shared-mutable-default-argument pitfall.
        self.label_list = [] if label_list is None else label_list
        self.root = root
        self.transforms = transforms

        if self.type == 'train':
            self._subdir = 'train'
        elif self.type in ('val', 'validation'):
            self._subdir = 'validation'
        else:
            # Previously an unknown split was silently ignored and failed much
            # later with AttributeError; fail fast with a clear message instead.
            raise ValueError(
                "type must be 'train', 'val' or 'validation', got %r" % (self.type,))

        # Sort both listings so that image i and annotation i belong to the
        # same sample — os.listdir returns entries in arbitrary order, so the
        # original unsorted listings were not guaranteed to be aligned.
        self.imgs = sorted(os.listdir(os.path.join(root, self._subdir, "JPEGImages")))
        self.bbox_xml = sorted(os.listdir(os.path.join(root, self._subdir, "Annotations")))

    def __getitem__(self, idx):
        """Return ``(img, target)`` for sample ``idx``."""
        img_path = os.path.join(self.root, self._subdir, "JPEGImages", self.imgs[idx])
        bbox_xml_path = os.path.join(self.root, self._subdir, "Annotations", self.bbox_xml[idx])

        # PIL images are converted to tensors later by the transform pipeline.
        img = Image.open(img_path).convert("RGB")

        # Parse the VOC-style annotation XML.
        dom = parse(bbox_xml_path)
        data = dom.documentElement
        objects = data.getElementsByTagName('object')
        boxes = []
        labels = []
        for object_ in objects:
            # <name> is the label string; its position in label_list is the class id.
            name = object_.getElementsByTagName('name')[0].childNodes[0].nodeValue
            labels.append(self.label_list.index(name))
            # getElementsByTagName returns a list; each object has exactly one <bndbox>.
            bndbox = object_.getElementsByTagName('bndbox')[0]
            # Plain float() — np.float was deprecated and removed in NumPy 1.24.
            xmin = float(bndbox.getElementsByTagName('xmin')[0].childNodes[0].nodeValue)
            ymin = float(bndbox.getElementsByTagName('ymin')[0].childNodes[0].nodeValue)
            xmax = float(bndbox.getElementsByTagName('xmax')[0].childNodes[0].nodeValue)
            ymax = float(bndbox.getElementsByTagName('ymax')[0].childNodes[0].nodeValue)
            boxes.append([xmin, ymin, xmax, ymax])

        # torchvision expects float32 boxes and int64 labels; reshape(-1, 4)
        # keeps a valid (0, 4) shape when an image has no annotated objects,
        # so the area computation below does not fail on an empty tensor.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        # Extra fields required by the COCO-style evaluator.
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # Assume no instance is marked as crowd; revisit if the data says otherwise.
        iscrowd = torch.zeros((len(objects),), dtype=torch.int64)

        # target must contain at least "boxes" and "labels"; the remaining
        # three keys are needed by the evaluation code.
        target = {"boxes": boxes,
                  "labels": labels,
                  "image_id": image_id,
                  "area": area,
                  "iscrowd": iscrowd}

        # Apply user-supplied augmentation / tensor conversion.
        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)


# Data augmentation — also the required preprocessing step: at minimum the
# pipeline must contain a ToTensor transform.
def get_transform(train):
    """Build the transform pipeline.

    Always starts with ToTensor; when ``train`` is True, augmentation
    transforms are appended as well.  Returns a ``T.Compose`` object.
    """
    # The pipeline is assembled as a list of transforms and composed at the end.
    ops = [T.ToTensor()]

    if train:
        ops += [
            T.RandomHorizontalFlip(0.1),                                    # random horizontal flip
            T.RandomCutout(n_holes=1, length_w=100, length_h=10, prob=0.15),  # random occlusion
        ]
        # NOTE(review): this coin flip happens once, when the pipeline is
        # built — so a given pipeline uses either padding (shrink) or crop
        # (zoom) for its whole lifetime, not a per-sample choice. Confirm
        # that this is the intended behaviour.
        scale_op = (T.RandomPadding(prob=0.3, mean=0, threshold=0.75)  # random shrink
                    if random.random() < 0.5
                    else T.RandomCrop(prob=0.3, threshold=1.5))        # random zoom
        ops.append(scale_op)

    # Compose the list into a single callable transform.
    return T.Compose(ops)


def main(args):
    """Train a Faster R-CNN detector on a VOC-style dataset.

    ``args`` must provide: epochs, batch_size, save_freq, print_freq,
    workers and dataset_root (see the argparse setup under __main__).
    Checkpoints are written to ``<dataset_root>/models``.
    """
    print(args, '\n')

    # Honour the --dataset-root option — it was previously parsed but then
    # ignored in favour of a hard-coded 'dataset' path.
    root = args.dataset_root
    save_root = os.path.join(root, 'models')
    # Create the checkpoint directory once, up front.
    os.makedirs(save_root, exist_ok=True)

    num_epochs = args.epochs
    batch_size = args.batch_size
    save_freq = args.save_freq
    print_freq = args.print_freq
    num_workers = args.workers

    # One label per line; trailing whitespace and blank lines are ignored.
    with open(os.path.join(root, "label_list.txt"), 'r') as fh:
        label_list = [line.rstrip() for line in fh]
    label_list = [label for label in label_list if label != '']

    # The class count follows directly from the label list.
    num_classes = len(label_list)
    print(label_list)
    print(num_classes)

    # Use the GPU when available.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Training split gets augmentation; validation split only ToTensor.
    dataset = MarkDataset(root, get_transform(train=True), type='train',
                          label_list=label_list, num_class=num_classes)
    dataset_test = MarkDataset(root, get_transform(train=False), type='val',
                               label_list=label_list, num_class=num_classes)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
        collate_fn=utils.collate_fn)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, shuffle=False, num_workers=1,
        collate_fn=utils.collate_fn)

    # COCO-pretrained backbone; swap the box-predictor head so its output
    # size matches our class count.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True, progress=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    model.to(device)

    # Optimise only the trainable parameters.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)

    # Cosine-annealing learning-rate schedule with warm restarts.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=2)

    for epoch in range(num_epochs):
        # train_one_epoch (engine.py) moves images/targets to `device` itself.
        train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=print_freq)

        # Advance the learning-rate schedule once per epoch.
        lr_scheduler.step()

        # Evaluate on the validation set after every epoch.
        evaluate(model, data_loader_test, device=device)

        # Periodic checkpoint of the whole model.
        if epoch % save_freq == 0:
            torch.save(model, os.path.join(save_root, str(epoch) + '.pkl'))

        print('')
        print('==================================================')
        print('')

    print("That's it!")

    # Save the final model.
    torch.save(model, os.path.join(save_root, 'last_model.pkl'))


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description=__doc__)

    # --epochs keeps its default of 26, so it no longer needs required=True
    # (previously both were set, which made the default dead code).
    parser.add_argument('--epochs', default=26, type=int, metavar='N',
                        help='number of total epochs to run (default: 26)')
    parser.add_argument('--print-freq', default=10, type=int,
                        help='print frequency (default: 10)')
    parser.add_argument('--save-freq', default=1, type=int,
                        help='model save frequency in epochs (default: 1)')
    parser.add_argument('-b', '--batch-size', default=1, type=int,
                        help='images per gpu; the total batch size is $NGPU x batch_size')
    parser.add_argument('-j', '--workers', default=12, type=int, metavar='N',
                        help='number of data loading workers (default: 12)')
    parser.add_argument('-r', '--dataset-root', default=r'dataset',
                        help='dataset root path')

    # Run training with the parsed command-line arguments.
    main(parser.parse_args())
